import matplotlib.pyplot as plt
import numpy as np
import os 
from copy import deepcopy

Plot_group = 2 # [1,2]: 1 is ViT, 2 is ResNet

# Working directory that the relative log paths below are resolved against.
top_dir = os.getcwd()
# Training logs to parse, grouped by plot: run name -> log path relative to top_dir.
read_path_dict_all = {

    # Plot1 VIT
    1:{
        'CIFAR20_SEQ':'experiments/SEQ_CIFAR100/2023-04-19-17-55-30/train.log',
        'CIFAR20_REPLAY':'experiments/REPLAY_CIFAR100/2023-04-19-17-55-28/train.log',
        'CIFAR20_MTL':'experiments/MTL_CIFAR100/2023-04-25-01-13-21/train.log',
    },

    # Plot2 ResNet
    2:{
        'CIFAR20_RESNET_SEQ':'experiments/SEQ_CIFAR100/2023-04-21-22-22-45/train.log', 
        'CIFAR20_RESNET_REPLAY':'experiments/REPLAY_CIFAR100/2023-04-21-22-22-47/train.log', 
        'CIFAR20_RESNET_MTL':'experiments/MTL_CIFAR100/2023-04-25-12-09-55/train.log',
    }
    
}
# Absolute log paths for the selected plot group (KeyError if Plot_group is not 1 or 2).
read_path_dict = {k:os.path.join(top_dir,v) for k,v in read_path_dict_all[Plot_group].items()}
save_path = os.path.join(top_dir,'plots/probing_result/')
# exist_ok=True replaces the racy "isdir() then makedirs()" check.
os.makedirs(save_path, exist_ok=True)

total_task = 20
# Per run: one list per task holding the value logged at every probing
# evaluation, plus the list of logged mean accuracies.
probe_result_acc = {k:[[] for _ in range(total_task)] for k in read_path_dict}
probe_aver_acc = {k:[] for k in read_path_dict}
probe_str = "Probe result = "

# Same layout as above, for the original (non-probing) test evaluations.
test_result_acc = {k:[[] for _ in range(total_task)] for k in read_path_dict}
test_aver_acc = {k:[] for k in read_path_dict}
test_str = "test result = "

# Per run: index of the first probing evaluation that reports each task.
task_init_x = {k:[] for k in read_path_dict}

task = 'IC' # 'NER' # "IC", "TC"

def _collect_result(line, marker, per_task_acc, aver_acc):
    """Parse the dict logged after ``marker`` in ``line`` and append its metrics.

    Args:
        line: a stripped log line known to contain ``marker``.
        marker: the prefix preceding the logged result dict (probe_str / test_str).
        per_task_acc: list of ``total_task`` lists; one value is appended to each,
            using -1 when the logged dict has no entry for that task.
        aver_acc: list that receives the logged 'Result_test_mean_acc'.
    """
    # NOTE(review): eval executes arbitrary expressions from the log file and is
    # only acceptable for locally produced, trusted logs. If the logged dicts
    # contain only literals, ast.literal_eval is a safer drop-in — confirm
    # against the log format before switching.
    result = eval(line.partition(marker)[2])
    # 'NER' runs log a per-task 'maf1' metric (presumably macro-F1); the other
    # tasks log per-task accuracy.
    metric_fmt = 'Result_test_maf1_%d' if task == 'NER' else 'Result_test_acc_%d'
    for task_id in range(total_task):
        # -1 marks "task not evaluated yet" at this step.
        per_task_acc[task_id].append(result.get(metric_fmt % task_id, -1))
    aver_acc.append(result['Result_test_mean_acc'])


for k, log_path in read_path_dict.items():
    with open(log_path) as f:
        for line in f:
            line = line.strip()
            # A line is checked for both markers independently, as before.
            if probe_str in line:
                _collect_result(line, probe_str, probe_result_acc[k], probe_aver_acc[k])
            if test_str in line:
                _collect_result(line, test_str, test_result_acc[k], test_aver_acc[k])

    # Record the first evaluation index at which each task has a real
    # (non-placeholder) probing value.
    for task_id in range(total_task):
        seen = np.where(np.array(probe_result_acc[k][task_id]) >= 0)[0]
        if seen.size == 0:
            raise ValueError('No probing result for task %d found in %s' % (task_id, log_path))
        task_init_x[k].append(seen[0])

# # ================================== Figure 1: the learning curve ==================================

# color_list = plt.get_cmap('tab20',20)
# init_acc_classifier = 1/(100/total_task)*100 # CIFAR100 (%)
# for k,v in read_path_dict.items():

#     for task_id in range(total_task):
#         y = np.array(probe_result_acc[k][task_id])
#         x = np.where(y>=0)[0]
#         line1, = plt.plot(x,y[x],color=color_list(0),linestyle='-')

#     for task_id in range(total_task):
#         y = np.array(test_result_acc[k][task_id])
#         x = np.where(y>=0)[0]
#         adjust_y = deepcopy(y[x])
#         adjust_y[0] = init_acc_classifier
#         line2, = plt.plot(x,adjust_y,color=color_list(1),linestyle='--')

#     line3, = plt.plot(list(range(len(probe_aver_acc[k]))),probe_aver_acc[k],linestyle='-',color='k',linewidth=2)
#     line4, = plt.plot(list(range(len(test_aver_acc[k]))),test_aver_acc[k],linestyle='--',color='k',linewidth=2)

#     if k in ['CIFAR20_RESNET_REPLAY','CIFAR20_REPLAY']:
#         plt.legend(handles=[line1,line2,line3,line4],
#                    labels=['Probing Accuracy of Each Task','Origin Accuracy of Each Task',
#                            'Average Probing Accuracy','Average Origin Accuracy'],fontsize=12)

#     # plt.title('Model %s'%(k))
#     plt.xlabel('Task',fontsize=15)    
#     plt.ylabel('ACC (%)',fontsize=15)    
#     plt.ylim(-3,103)
#     plt.yticks(fontsize=12)
#     plt.xticks([x for x in task_init_x[k]],['%d'%(i+1) for i in range(total_task)],fontsize=12)
#     plt.savefig(os.path.join(save_path,'acc_%s.pdf'%(k)),dpi=1200,bbox_inches='tight')
#     plt.clf()

# # ==============================================================================================

# # ================================== Figure 2: the acc/fgt and distance ==================================

def _summarize_run(probe_curves, test_curves):
    """Summarise one run's per-task curves into forgetting/accuracy metrics.

    Each curve is the list of values logged for one task across evaluations,
    with -1 as the placeholder before the task was seen. Forgetting of a task
    is its best value minus its final value. Both forgetting metrics are stored
    negated (fgt_* <= 0), and the acc_* metrics are the mean final values.
    """
    def mean_forgetting(curves):
        return np.mean([np.max(curve)-curve[-1] for curve in curves])

    def mean_final(curves):
        return np.mean([curve[-1] for curve in curves])

    summary = {
        'fgt_probe': -1*mean_forgetting(probe_curves),
        'fgt_origin': -1*mean_forgetting(test_curves),
        'acc_probe': mean_final(probe_curves),
        'acc_origin': mean_final(test_curves),
    }
    summary['fgt_diff'] = summary['fgt_probe'] - summary['fgt_origin']
    summary['acc_diff'] = summary['acc_probe'] - summary['acc_origin']
    return summary


# Run name -> summary metrics, in the same order as read_path_dict.
result_dict = {
    run: _summarize_run(probe_result_acc[run], test_result_acc[run])
    for run in read_path_dict
}

width = 0.2       # width of a single bar
num_measure = 3   # bars per method group: two metrics + the embedding distance
color_list = plt.get_cmap('Set1',4)

# Bar order, tick labels and feature-embedding distances for each plot group.
# The distances are hard-coded here (not computed by this script) — presumably
# measured separately; TODO record where they come from.
if Plot_group == 1:
    # Plot1 VIT
    method_list = ['CIFAR20_SEQ','CIFAR20_REPLAY','CIFAR20_MTL']
    method_name = ['ViT+SEQ','ViT+REPLAY','ViT+MTL']
    distance_y = [0.8312,0.5893,0.1054] # vit
elif Plot_group == 2:
    # Plot2 ResNet
    method_list = ['CIFAR20_RESNET_SEQ','CIFAR20_RESNET_REPLAY','CIFAR20_RESNET_MTL']
    method_name = ['Res.+SEQ','Res.+REPLAY','Res.+MTL']
    distance_y = [0.8270,0.2741,0.1510] # resnet
else:
    # Fail here instead of with a NameError when the figures are drawn.
    raise ValueError('Unsupported Plot_group %r; expected 1 (ViT) or 2 (ResNet)' % (Plot_group,))

# One x slot per plotted method. Derived from method_list, which is what the
# bars are drawn from, rather than from the number of parsed logs.
x = list(range(len(method_list)))

# # -------------------------------------------------------------------------------------------------
# Figure 2a: per-method forgetting. Probing and original forgetting (<= 0)
# share the left axis, whose bottom spine sits at y=0. The hard-coded
# feature-embedding distance is drawn as a third bar on a twin right axis.
def _bar_slot(slot):
    """Return x positions of bar ``slot`` (0-based) within each method's group."""
    return [tmp_x-width*(0.5*num_measure-0.5)+slot*width for tmp_x in x]

fig, ax1 = plt.subplots(figsize=(8,5))
local_x_1 = _bar_slot(0)
y_1 = [result_dict[k]['fgt_probe'] for k in method_list]
bar1 = ax1.bar(local_x_1, y_1, width=width,label='Fgt_Probe',color=color_list(0))

local_x_2 = _bar_slot(1)
y_2 = [result_dict[k]['fgt_origin'] for k in method_list]
bar2 = ax1.bar(local_x_2, y_2, width=width,label='Fgt_Origin',color=color_list(0),alpha=0.5)

ax1.set_ylabel('Forgetting (%)',fontsize=15)
ax1.spines['bottom'].set_position(('data', 0))
# NOTE(review): `position` is passed through to the tick-label Text objects;
# tick layout may override it — confirm it has the intended effect.
ax1.set_xticks(x,list(method_name),position=(0,-130),fontsize=15)
ax1.set_yticks([0,-25,-50,-75,-100],[0,-25,-50,-75,-100],fontsize=15)
ax1.set_ylim(bottom=-130,top=100)

ax2 = ax1.twinx()
local_x_3 = _bar_slot(2)
y_3 = distance_y
bar3 = ax2.bar(local_x_3, y_3, width=width,label='Feat_Embed_Dist',color=color_list(2))

# Forgetting labels are nudged 10 units below the (negative) bar end.
for _x, _y in zip(local_x_1, y_1):
    ax1.text(_x,_y-10,'%.1f'%(_y),ha='center', va='bottom',fontsize=12)
for _x, _y in zip(local_x_2, y_2):
    ax1.text(_x,_y-10,'%.1f'%(_y),ha='center', va='bottom',fontsize=12)
for _x, _y in zip(local_x_3, y_3):
    ax2.text(_x,_y,'%.2f'%(_y),ha='center', va='bottom',fontsize=12)

ax2.set_yticks([0.0,0.2,0.4,0.6,0.8,1.0],[0.0,0.2,0.4,0.6,0.8,1.0],fontsize=15)
ax2.set_ylim(bottom=-1.3,top=1.0)
ax2.set_ylabel('Feature Embedding Distance',fontsize=15)

ax2.legend(handles=[bar1,bar2,bar3],labels=['Probing Forgetting','Original Forgetting','Feat. Embed. Distance'],loc='lower right',fontsize=15)

# File-name suffix per plot group. An unknown group raises KeyError instead
# of silently skipping the save.
model_suffix = {1: 'vit', 2: 'resnet'}[Plot_group]
for ext in ('pdf', 'png'):
    fig.savefig(os.path.join(save_path,'CIFAR20_EmbFeat_Fgt_%s.%s'%(model_suffix,ext)),dpi=1200,bbox_inches='tight')
# Close (not just clear) the figure so pyplot releases it.
plt.close(fig)
# # -------------------------------------------------------------------------------------------------

# # -------------------------------------------------------------------------------------------------
# Figure 2b: per-method final average accuracy. Probing and original accuracy
# share the left axis; the hard-coded feature-embedding distance is drawn as a
# third bar on a twin right axis. Both y ranges extend past their last tick
# (145 / 1.45), which leaves room for the labels and legend.
def _bar_slot(slot):
    """Return x positions of bar ``slot`` (0-based) within each method's group."""
    return [tmp_x-width*(0.5*num_measure-0.5)+slot*width for tmp_x in x]

fig, ax1 = plt.subplots(figsize=(8,5))
local_x_1 = _bar_slot(0)
y_1 = [result_dict[k]['acc_probe'] for k in method_list]
bar1 = ax1.bar(local_x_1, y_1, width=width,label='Acc_Probe',color=color_list(1))

local_x_2 = _bar_slot(1)
y_2 = [result_dict[k]['acc_origin'] for k in method_list]
bar2 = ax1.bar(local_x_2, y_2, width=width,label='Acc_Origin',color=color_list(1),alpha=0.5)

ax1.set_ylabel('Average Accuracy (%)',fontsize=15)
ax1.set_xticks(x,list(method_name),fontsize=15)
ax1.set_yticks([0,25,50,75,100],[0,25,50,75,100],fontsize=15)
ax1.set_ylim(bottom=0,top=145)

ax2 = ax1.twinx()
local_x_3 = _bar_slot(2)
y_3 = distance_y
bar3 = ax2.bar(local_x_3, y_3, width=width,label='Feat_Emb_Dist',color=color_list(2))

# Value labels sit just above each bar.
for _x, _y in zip(local_x_1, y_1):
    ax1.text(_x,_y,'%.1f'%(_y),ha='center', va='bottom',fontsize=12)
for _x, _y in zip(local_x_2, y_2):
    ax1.text(_x,_y,'%.1f'%(_y),ha='center', va='bottom',fontsize=12)
for _x, _y in zip(local_x_3, y_3):
    ax2.text(_x,_y,'%.2f'%(_y),ha='center', va='bottom',fontsize=12)

ax2.set_yticks([0.0,0.2,0.4,0.6,0.8,1.0],[0.0,0.2,0.4,0.6,0.8,1.0],fontsize=15)
ax2.set_ylim(bottom=0,top=1.45)
ax2.set_ylabel('Feature Embedding Distance',fontsize=15)

ax2.legend(handles=[bar1,bar2,bar3],labels=['Probing Accuracy','Original Accuracy','Feat. Embed. Distance'],fontsize=15)

# File-name suffix per plot group. An unknown group raises KeyError instead
# of silently skipping the save.
model_suffix = {1: 'vit', 2: 'resnet'}[Plot_group]
for ext in ('pdf', 'png'):
    fig.savefig(os.path.join(save_path,'CIFAR20_EmbFeat_Acc_%s.%s'%(model_suffix,ext)),dpi=1200,bbox_inches='tight')
# Close (not just clear) the figure so pyplot releases it.
plt.close(fig)
# -------------------------------------------------------------------------------------------------